#ifndef __GRAPH_HPP__
#define __GRAPH_HPP__

#include <climits>
#include <functional>
#include <iostream>
#include <queue>
#include <stack>
#include <stdexcept>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../UnionFindSet/UnionFindSet.hpp"
using namespace std;

namespace clx_datastructure {
    /*
        Adjacency-matrix graph.

        V           : vertex type (used as an unordered_map key, so it must be hashable)
        W           : edge-weight type
        DFLT_WEIGHT : value stored in the matrix where there is no edge;
                      a real edge must never carry this weight
        Direction   : false -> undirected (addEdge stores both directions),
                      true  -> directed
        Note: W is also the type of a non-type template parameter, so before
        C++20 it must be an integral type.
    */
    template<class V, class W, W DFLT_WEIGHT = INT_MIN, bool Direction = false>
    class graph {
      public:
        /// Builds a graph over `vertexs_` with no edges.
        /// Each vertex gets the index of its position in `vertexs_`
        /// (if a vertex appears more than once, its last position wins).
        graph(const vector<V>& vertexs_)
            : vertexs(vertexs_),
              matrix(vertexs_.size(), vector<W>(vertexs_.size(), DFLT_WEIGHT)) {
            for (size_t i = 0; i < vertexs.size(); i++) {
                vIndexMap[vertexs[i]] = static_cast<int>(i);
            }
        }

        /// Returns the matrix index of vertex `v`.
        /// @throws invalid_argument if `v` is not a vertex of this graph.
        int getVertexIndex(const V& v) const {
            auto it = vIndexMap.find(v);
            if (it == vIndexMap.end()) {
                throw invalid_argument("不存在的顶点");
            }
            return it->second;
        }

        /// Sets the weight of edge src -> dst (and dst -> src when undirected),
        /// overwriting any existing weight. Self-loops are silently ignored.
        /// @throws invalid_argument if either endpoint is not a vertex.
        void addEdge(const V& src, const V& dst, const W& weight) {
            const int srci = getVertexIndex(src);
            const int dsti = getVertexIndex(dst);
            if (srci == dsti) return;  // self-loops are never stored
            matrix[srci][dsti] = weight;
            if (!Direction) {
                matrix[dsti][srci] = weight;
            }
        }

        /// Breadth-first traversal from `v`, printing one BFS level per line.
        /// Does nothing if `v` is not a vertex of this graph.
        void bfs(const V& v) {
            auto start = vIndexMap.find(v);
            if (start == vIndexMap.end()) return;

            const int n = static_cast<int>(vertexs.size());
            vector<bool> visited(n, false);
            queue<int> q;
            q.push(start->second);
            visited[start->second] = true;  // mark on enqueue so each vertex is queued once
            while (!q.empty()) {
                const size_t levelSize = q.size();
                for (size_t k = 0; k < levelSize; k++) {
                    const int src = q.front();
                    q.pop();
                    cout << vertexs[src] << " ";
                    for (int i = 0; i < n; i++) {
                        if (matrix[src][i] != DFLT_WEIGHT && !visited[i]) {
                            q.push(i);
                            visited[i] = true;
                        }
                    }
                }
                cout << endl;
            }
        }

        /// Recursive depth-first traversal from matrix index `index`, printing
        /// each vertex when first visited. `signs` records visited indices and
        /// must have one entry per vertex.
        void dfs(int index, vector<bool>& signs) {
            signs[index] = true;
            cout << vertexs[index] << " ";
            const int n = static_cast<int>(vertexs.size());
            for (int i = 0; i < n; i++) {
                if (matrix[index][i] != DFLT_WEIGHT && !signs[i]) {
                    dfs(i, signs);
                }
            }
        }

        /// Depth-first traversal from vertex `v`; does nothing if `v` is unknown.
        void dfs(const V& v) {
            auto start = vIndexMap.find(v);
            if (start == vIndexMap.end()) return;
            vector<bool> signs(vertexs.size(), false);
            dfs(start->second, signs);
        }

        /// Kruskal's minimum spanning tree.
        /// Only the upper triangle of the matrix is scanned, so this is meant for
        /// undirected graphs. Each chosen edge is added to `g`, which must contain
        /// every vertex of this graph (otherwise g.addEdge throws).
        /// @return total weight of the tree, or -1 if the graph is empty or not connected.
        W min_create_tree_kruskal(graph<V, W>& g) {
            typedef pair<W, pair<int, int>> Edge;  // (weight, (i, j))
            const int n = static_cast<int>(vertexs.size());
            priority_queue<Edge, vector<Edge>, greater<Edge>> q;
            UnionFindSet<V> ufs(vertexs);
            for (int i = 0; i < n; i++) {
                for (int j = i + 1; j < n; j++) {
                    if (matrix[i][j] != DFLT_WEIGHT) {
                        q.push({matrix[i][j], {i, j}});
                    }
                }
            }

            W result = W();  // accumulate in W, not int, so wider weight types don't truncate
            int count = 0;
            while (count < n - 1 && !q.empty()) {
                const Edge e = q.top();
                q.pop();
                const int i = e.second.first;
                const int j = e.second.second;
                if (!ufs.inSameSet(vertexs[i], vertexs[j])) {
                    ufs.unionSet(vertexs[i], vertexs[j]);
                    g.addEdge(vertexs[i], vertexs[j], e.first);
                    result += e.first;
                    count++;
                }
            }
            return count == n - 1 ? result : static_cast<W>(-1);
        }

        /// Prim's minimum spanning tree, grown from the vertex at index 0.
        /// Each chosen edge is added to `g`, which must contain every vertex of
        /// this graph (otherwise g.addEdge throws).
        /// @return total weight of the tree, or -1 if the graph is empty or not connected.
        W min_create_tree_prim(graph<V, W>& g) {
            typedef pair<W, pair<int, int>> Edge;  // (weight, (from, to))
            const int n = static_cast<int>(vertexs.size());
            if (n == 0) return static_cast<W>(-1);

            vector<bool> inTree(n, false);
            priority_queue<Edge, vector<Edge>, greater<Edge>> q;
            W result = W();
            int count = 0;
            int cur = 0;  // most recently added tree vertex
            inTree[cur] = true;
            while (count < n - 1) {
                // Offer every edge leaving the newly added vertex.
                for (int i = 0; i < n; i++) {
                    if (!inTree[i] && matrix[cur][i] != DFLT_WEIGHT) {
                        q.push({matrix[cur][i], {cur, i}});
                    }
                }
                // Lazily drop edges whose far end already joined the tree;
                // check emptiness first because top() on an empty queue is UB.
                while (!q.empty() && inTree[q.top().second.second]) q.pop();
                if (q.empty()) break;  // remaining vertices are unreachable

                // Copy the edge before pop() so its source endpoint is still valid.
                const Edge e = q.top();
                q.pop();
                cur = e.second.second;
                g.addEdge(vertexs[e.second.first], vertexs[cur], e.first);
                result += e.first;
                inTree[cur] = true;
                count++;
            }
            return count == n - 1 ? result : static_cast<W>(-1);
        }

        /// Prints the adjacency matrix, showing missing edges as 0.
        void printMatrix() {
            const int n = static_cast<int>(vertexs.size());
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    if (matrix[i][j] == DFLT_WEIGHT) cout << "0 ";
                    else cout << matrix[i][j] << " ";
                }
                cout << endl;
            }
        }

        /// Prints the path src -> ... -> dest recorded in `parent`
        /// (parent[x] is x's predecessor, -1 for none), e.g. "A->B->C".
        /// Prints nothing if following `parent` from `dest` does not reach `src`.
        void print_path(const vector<int>& parent, int src, int dest) {
            stack<int> st;
            int cur = dest;
            st.push(cur);
            while (cur != src) {
                // A valid path has at most vertexs.size() vertices; more means a cycle.
                if (st.size() >= vertexs.size()) return;
                cur = parent[cur];
                if (cur < 0) return;  // chain ends before reaching src
                st.push(cur);
            }

            bool first = true;
            while (!st.empty()) {
                if (!first) cout << "->";
                cout << vertexs[st.top()];
                first = false;
                st.pop();
            }
            cout << endl;
        }

        /// Dijkstra's shortest path from `src` to `dest`; assumes non-negative weights.
        /// When `dest` is reachable, the path is printed via print_path.
        /// @return shortest distance, or -1 if `dest` is unreachable from `src`.
        /// @throws invalid_argument if `src` or `dest` is not a vertex.
        W Dijkstra(const V& src, const V& dest) {
            const int n = static_cast<int>(vertexs.size());
            const int srci = getVertexIndex(src);
            const int desti = getVertexIndex(dest);
            vector<int> parent(n, -1);
            vector<W> dist(n, W());
            vector<bool> reached(n, false);  // dist[j] is meaningful only once reached[j]
            vector<bool> settled(n, false);
            reached[srci] = true;

            for (int round = 0; round < n; round++) {
                // Pick the closest reached-but-unsettled vertex (lowest index on ties).
                int u = -1;
                for (int j = 0; j < n; j++) {
                    if (reached[j] && !settled[j] && (u == -1 || dist[j] < dist[u])) u = j;
                }
                if (u == -1) break;  // everything reachable has been settled
                settled[u] = true;

                for (int j = 0; j < n; j++) {
                    if (settled[j] || matrix[u][j] == DFLT_WEIGHT) continue;
                    const W candidate = dist[u] + matrix[u][j];
                    if (!reached[j] || candidate < dist[j]) {
                        dist[j] = candidate;
                        parent[j] = u;
                        reached[j] = true;
                    }
                }
            }
            if (!reached[desti]) return static_cast<W>(-1);
            print_path(parent, srci, desti);
            return dist[desti];
        }

        /// Bellman-Ford shortest path from `src` to `dest` (negative weights allowed).
        /// When a shortest path exists, it is printed via print_path.
        /// @return shortest distance, or -1 if `dest` is unreachable or a
        ///         negative cycle is reachable from `src`.
        /// @throws invalid_argument if `src` or `dest` is not a vertex.
        int Bellman_Ford(const V& src, const V& dest) {
            const int n = static_cast<int>(vertexs.size());
            const int srci = getVertexIndex(src);
            const int desti = getVertexIndex(dest);
            vector<int> parent(n, -1);
            vector<W> dist(n, W());
            vector<bool> reached(n, false);  // dist[j] is meaningful only once reached[j]
            reached[srci] = true;

            // Without a negative cycle every shortest path has at most n-1 edges,
            // so one of the first n passes changes nothing. A change on all n
            // passes therefore means a negative cycle is reachable from src.
            bool converged = false;
            for (int pass = 0; pass < n && !converged; pass++) {
                bool relaxed = false;
                for (int i = 0; i < n; i++) {
                    if (!reached[i]) continue;
                    for (int j = 0; j < n; j++) {
                        if (matrix[i][j] == DFLT_WEIGHT) continue;
                        const W candidate = dist[i] + matrix[i][j];
                        if (!reached[j] || candidate < dist[j]) {
                            dist[j] = candidate;
                            parent[j] = i;
                            reached[j] = true;
                            relaxed = true;
                        }
                    }
                }
                converged = !relaxed;
            }
            if (!converged || !reached[desti]) return -1;
            print_path(parent, srci, desti);
            return static_cast<int>(dist[desti]);
        }

        /// Floyd-Warshall all-pairs shortest paths.
        /// On return, vvDest[i][j] is the shortest i -> j distance (MAX_W if
        /// unreachable, 0 on the diagonal) and vvParent[i][j] is j's predecessor
        /// on that path (-1 if none). Any previous contents of both outputs are
        /// discarded. MAX_W must exceed every real path length.
        template<W MAX_W = INT_MAX>
        void FloydWarshall(vector<vector<W>>& vvDest, vector<vector<int>>& vvParent) {
            const int n = static_cast<int>(vertexs.size());
            // assign, not resize: resize would keep stale values already in the outputs.
            vvDest.assign(n, vector<W>(n, MAX_W));
            vvParent.assign(n, vector<int>(n, -1));

            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    if (matrix[i][j] != DFLT_WEIGHT) {
                        vvDest[i][j] = matrix[i][j];
                        vvParent[i][j] = i;
                    }
                }
                vvDest[i][i] = W();  // a vertex is at distance 0 from itself
            }

            for (int k = 0; k < n; k++) {
                for (int i = 0; i < n; i++) {
                    for (int j = 0; j < n; j++) {
                        // Skip MAX_W operands so the sum cannot overflow.
                        if (vvDest[i][k] != MAX_W && vvDest[k][j] != MAX_W &&
                            vvDest[i][j] > vvDest[i][k] + vvDest[k][j]) {
                            vvDest[i][j] = vvDest[i][k] + vvDest[k][j];
                            vvParent[i][j] = vvParent[k][j];
                        }
                    }
                }
            }
        }

      private:
        vector<V> vertexs;                // vertex set; position = matrix index
        unordered_map<V, int> vIndexMap;  // vertex -> matrix index
        vector<vector<W>> matrix;         // adjacency matrix; DFLT_WEIGHT means "no edge"
    };
}

#endif

